Support deepspeed sequence parallel #31525

Open
wants to merge 15 commits into main

Conversation

@zeyugao commented Jun 20, 2024

What does this PR do?

Support sequence parallelism with DeepSpeed-Ulysses.

I have tested the training on starcoder2-3b. The loss decreases normally.

[Screenshot: training loss curve]

Requires huggingface/accelerate#2877

I have made substantial modifications to the original DeepSpeed-Ulysses implementation in layers.py to support the batch size dimension. It uses all_to_all_single instead of all_to_all (as in https://github.com/InternLM/InternEvo/blob/a61d391df96c5f5c243cdea32a5044b70d6fe33e/internlm/core/parallel/comm/isp.py#L628) for better performance, and I have left some comments to aid future understanding. With all_to_all_single, supporting other scatter/gather indices would be too complex.
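
For context, here is a minimal sketch of an Ulysses-style redistribution built on a single all_to_all_single call. This is a hypothetical standalone helper, not the PR's layers.py; it assumes equal splits, one sequence-parallel process group, and a head count divisible by the sequence-parallel degree.

    import torch
    import torch.distributed as dist

    def seq_to_head_all_to_all(x: torch.Tensor, group) -> torch.Tensor:
        # Redistribute [b, s/p, n, h] -> [b, s, n/p, h] over the sequence-parallel group.
        p = dist.get_world_size(group)
        b, s_local, n, h = x.shape
        # Split the head dim into p chunks and move the chunk index to dim 0;
        # all_to_all_single scatters dim-0 chunks across ranks.
        x = x.reshape(b, s_local, p, n // p, h).permute(2, 1, 0, 3, 4).contiguous()
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x, group=group)
        # After the exchange, dim 0 indexes the p sequence shards of the full sequence.
        return out.permute(2, 0, 1, 3, 4).reshape(b, s_local * p, n // p, h)

Attention then runs over the full sequence with n/p heads per rank, and the inverse all-to-all restores the [b, s/p, n, h] layout afterwards.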

Currently, flash attention and SDPA are tested for llama and mistral. Flash attention is also tested for starcoder; SDPA for starcoder is not supported.

It requires a special dataloader (implemented in the Trainer) and a data collator (example below). In the data collator, the sequence should be divided into multiple sub-sequences. The following is an example of sub-sequence processing in the data collator:

            seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
            seq_parallel_world_rank = mpu.get_sequence_parallel_rank()

            seq_length = input_ids.size(1)
            sub_seq_length = seq_length // seq_parallel_world_size
            sub_seq_start = seq_parallel_world_rank * sub_seq_length
            sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length

            # There is no kv cache when training
            past_key_values_length = 0

            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

            batch = dict(
                input_ids=input_ids[:, sub_seq_start:sub_seq_end],
                labels=labels[:, sub_seq_start:sub_seq_end],
                position_ids=position_ids[:, sub_seq_start:sub_seq_end],
                attention_mask=(input_ids != self.tokenizer.pad_token_id),
            )
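
As a self-contained variant of the snippet above, the collator could call a helper like the following after padding. This is a hypothetical function, not part of the PR; it assumes the padded length is a multiple of the sequence-parallel degree and that mpu is the same parallel-state module used above.

    import torch

    def shard_batch_for_sequence_parallel(input_ids, labels, pad_token_id, mpu):
        # Keep only this rank's contiguous sub-sequence; position_ids stay global.
        sp_size = mpu.get_sequence_parallel_world_size()
        sp_rank = mpu.get_sequence_parallel_rank()
        seq_length = input_ids.size(1)
        assert seq_length % sp_size == 0, "pad to a multiple of the sequence-parallel degree"
        sub_len = seq_length // sp_size
        start, end = sp_rank * sub_len, (sp_rank + 1) * sub_len
        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)
        return dict(
            input_ids=input_ids[:, start:end],
            labels=labels[:, start:end],
            position_ids=position_ids[:, start:end],
            # The attention mask is kept at full sequence length, as in the example above.
            attention_mask=(input_ids != pad_token_id),
        )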

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr and @SunMarc

@zeyugao marked this pull request as draft June 21, 2024 02:25
@zeyugao marked this pull request as ready for review June 21, 2024 10:21
@fan-niu commented Jun 27, 2024

Great! Can you provide an example of data processing based on sequence parallelism? Thanks.

@zeyugao (Author) commented Jun 27, 2024

The dataset and sampler are handled in the Trainer

https://github.com/huggingface/transformers/pull/31525/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR847-R855

The data collator example was accidentally deleted while editing; here it is again:

            seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
            seq_parallel_world_rank = mpu.get_sequence_parallel_rank()

            seq_length = input_ids.size(1)
            sub_seq_length = seq_length // seq_parallel_world_size
            sub_seq_start = seq_parallel_world_rank * sub_seq_length
            sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length

            # There is no kv cache when training
            past_key_values_length = 0

            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

            batch = dict(
                input_ids=input_ids[:, sub_seq_start:sub_seq_end],
                labels=labels[:, sub_seq_start:sub_seq_end],
                position_ids=position_ids[:, sub_seq_start:sub_seq_end],
                attention_mask=(input_ids != self.tokenizer.pad_token_id),
            )

@github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this Aug 18, 2024
@SunMarc reopened this Aug 26, 2024
@ArthurZucker added the Feature request and DeepSpeed labels Aug 27, 2024
@ldh127 commented Sep 22, 2024

How long will it take for this PR to be merged? When will it be finished?

@LysandreJik (Member) commented
cc @SunMarc if you have the bandwidth to take a look!

@glowwormX commented
@zeyugao I carefully read your pull requests for transformers and accelerate and pulled your code to try training. I have now run into a problem: when entering DistributedAttention, the q, k, v before _SeqAllToAll.apply are not [b, s/p, n, h] but still [b, s, n, h]. I checked the modified data-processing parts, such as accelerate/data_loader.py and transformers/trainer.py, but did not find any relevant code. So, may I ask where the sequence splitting is done?

@zeyugao (Author) commented Oct 8, 2024

@glowwormX It is in the PR description:

[Screenshot of the data collator example from the PR description]

@glowwormX commented
@zeyugao My god, I missed it; I thought this code was in the PR. Thank you for replying.

@glowwormX commented
@zeyugao Have you compared the loss with and without sequence parallelism? After adding a fixed seed to the DistributedSampler, the training data is the same. I modified trainer.py as follows:

        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
            assert self.args.group_by_length is False, "Group by length is not supported with sequence parallel."
            return DistributedSampler(
                dataset=self.train_dataset,
                num_replicas=mpu.get_data_parallel_world_size(),
                rank=mpu.get_data_parallel_rank(),
                shuffle=True,
                seed=42
            )

However, on the same data, the average loss with sequence parallelism differs from the loss without it.

In addition, why is SDPA not supported for starcoder? I am trying to modify qwen2 and do not know whether SDPA is unsupported there as well.

@zeyugao (Author) commented Oct 19, 2024

@glowwormX The main reason should be that it needs a custom loss calculation; otherwise some tokens (at the head and tail of each sub-sequence) do not contribute to the final loss: https://github.com/microsoft/DeepSpeed/pull/5774/files#diff-13f25bb51b0f4019d8cb09c07204a33510dca5dccfae736baf10134f893704d5
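
To illustrate the boundary effect, here is a sketch of the standard per-rank causal LM shift (not the PR's code): with the shift applied independently inside each sub-sequence, the last position of every sub-sequence has no local target and the first label of every sub-sequence is never predicted locally, unless logits or labels are exchanged across the sequence-parallel group.

    import torch.nn.functional as F

    def local_causal_lm_loss(logits, labels):
        # Stock per-rank shift: drops the prediction at the sub-sequence's last position
        # and never predicts the sub-sequence's first label.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )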

the reason why starcoder does not support sdpa

I did not have enough spare time at the time to get the shapes correct when using SDPA for starcoder2.

@ronald-d-rogers commented
@zeyugao: Your implementation does not use this loss function, right? Does it still work OK even so?

Labels: DeepSpeed, Feature request
8 participants